from prompts import *
from langchain.agents.initialize import initialize_agent
from langchain.chains.conversation.memory import ConversationBufferMemory
from langchain.llms.openai import OpenAI
import gradio as gr
import re
from tools import bar_tool, line_tool, pie_tool, scatter_tool
from llms.tk.tkllm import CustomLLM


class ConversationBot:
    """Conversational agent wrapper used as the backend of the Gradio chat UI.

    The bot owns the tool list (line, bar, pie and scatter tools imported from
    ``tools``), a ``CustomLLM`` instance and a ``ConversationBufferMemory``.
    The agent itself is created lazily by :meth:`init_agent`, because the
    prompt set depends on the language the user selects.
    """

    # Matches generated image paths such as ``image/chart-1.png``. The dot is
    # escaped so that only a literal ``.png`` extension is accepted, and the
    # pattern is a raw string so ``\w`` is not treated as a string escape.
    _IMAGE_PATH_RE = re.compile(r'(image/[-\w]*\.png)')

    def __init__(self):
        """Create the tools, LLM and memory; the agent is built in ``init_agent``."""
        self.tools = [line_tool, bar_tool, pie_tool, scatter_tool]
        print(f"All the Available Tools: {self.tools}")
        self.llm = CustomLLM()
        self.memory = ConversationBufferMemory(memory_key="chat_history", output_key='output')
        # Set by init_agent(); kept as None until a language has been chosen so
        # that run_text() can report a clear error instead of AttributeError.
        self.agent = None

    def init_agent(self, lang):
        """(Re)build the agent for the selected language and reset the memory.

        Args:
            lang: ``'English'`` selects the English prompt set; any other value
                selects the ``*_CN`` prompt set.

        Returns:
            A 4-tuple of ``gr.update`` objects, in order: ``visible=True``,
            ``visible=False``, a placeholder update, and a value update
            carrying the label of the clear button. They are meant to be
            mapped onto the output components passed by the caller.
        """
        self.memory.clear()  # clear previous history
        if lang == 'English':
            PREFIX, FORMAT_INSTRUCTIONS, SUFFIX = CHARTGPT_PREFIX, CHARTGPT_FORMAT_INSTRUCTIONS, CHARTGPT_SUFFIX
            place = "Enter text and press enter"
            label_clear = "Clear"
        else:
            PREFIX, FORMAT_INSTRUCTIONS, SUFFIX = CHARTGPT_PREFIX_CN, CHARTGPT_FORMAT_INSTRUCTIONS_CN, CHARTGPT_SUFFIX_CN
            place = "输入文字并回车"
            label_clear = "清除"
        self.agent = initialize_agent(
            self.tools,
            self.llm,
            agent="conversational-react-description",
            verbose=True,
            memory=self.memory,
            return_intermediate_steps=True,
            agent_kwargs={'prefix': PREFIX, 'format_instructions': FORMAT_INSTRUCTIONS,
                          'suffix': SUFFIX}, )
        print("init bot finished!")
        return gr.update(visible=True), gr.update(visible=False), gr.update(placeholder=place), gr.update(
            value=label_clear)

    def run_text(self, text, state):
        """Send one user message to the agent and append the reply to the chat state.

        Backslashes in the agent output are normalised to forward slashes, and
        every ``image/<name>.png`` path is rewritten into a markdown image
        followed by the path in italics.

        Args:
            text: Raw user input; it is stripped before being sent to the agent
                but stored unmodified in the chat history.
            state: List of ``(user_text, response)`` tuples. It is not mutated;
                a new list is returned.

        Returns:
            The updated history twice (one copy per output of the caller).

        Raises:
            RuntimeError: If ``init_agent`` has not been called yet.
        """
        if self.agent is None:
            raise RuntimeError("Agent is not initialized; call init_agent() first.")
        print(self.agent.memory)
        res = self.agent({"input": text.strip()})
        print("res:", res)
        # Normalise Windows-style separators so the image pattern can match.
        res['output'] = res['output'].replace("\\", "/")
        response = self._IMAGE_PATH_RE.sub(r'![](file=\g<1>)*\g<1>*', res['output'])
        state = state + [(text, response)]
        print(f"\nProcessed run_text, Input text: {text}\nCurrent state: {state}\n")
        return state, state


if __name__ == '__main__':
    # Build the Gradio UI around a single shared bot instance and serve it.
    bot = ConversationBot()
    with gr.Blocks(css="#chatbot .overflow-y-auto{height:500px}") as demo:
        # Language selector; choosing a value triggers agent initialisation.
        language_radio = gr.Radio(choices=['Chinese', 'English'], value=None, label='Language')
        chat_view = gr.Chatbot(elem_id="chatbot", label="ChartGPT")
        history = gr.State([])

        # The input row starts hidden; bot.init_agent returns the update that reveals it.
        with gr.Row(visible=False) as input_row:
            with gr.Column(scale=0.7):
                prompt_box = gr.Textbox(
                    show_label=False,
                    placeholder="Enter text and press enter, or upload an image",
                ).style(container=False)
            with gr.Column(scale=0.15, min_width=0):
                clear_btn = gr.Button("Clear")

        # Selecting a language builds the agent and updates the four components below.
        language_radio.change(
            fn=bot.init_agent,
            inputs=[language_radio],
            outputs=[input_row, language_radio, prompt_box, clear_btn],
        )

        # Submitting text runs the agent, then empties the textbox.
        prompt_box.submit(fn=bot.run_text, inputs=[prompt_box, history], outputs=[chat_view, history])
        prompt_box.submit(fn=lambda: "", inputs=None, outputs=prompt_box)

        # "Clear" wipes the agent memory, the visible chat and the stored history.
        clear_btn.click(fn=bot.memory.clear)
        clear_btn.click(fn=lambda: [], inputs=None, outputs=chat_view)
        clear_btn.click(fn=lambda: [], inputs=None, outputs=history)
    demo.launch(server_name="0.0.0.0", server_port=7861)
